The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
from jax import config
import pickle
config.update("jax_enable_x64", True)
from ssm4epi.models.hospitalization import (
hospitalization_model,
dates,
unique_a,
h_by_age,
I_by_age,
)
from ssm4epi.models.util import from_consecutive_logits
from isssm.importance_sampling import mc_integration
from isssm.kalman import state_mode
from jax import vmap
from ssm4epi.patch import full_deps
import jax.numpy as jnp
from pyprojroot import here
import matplotlib.pyplot as plt
from isssm.estimation import initial_theta, mle_pgssm
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
modified_efficient_importance_sampling as MEIS,
)
from isssm.importance_sampling import pgssm_importance_sampling, ess_pct
import jax.random as jrn
from tqdm.notebook import tqdm
# Run hospitalization model on all data until final NCH evaluation date.
from typing import NamedTuple
from jaxtyping import Float, Array, PRNGKeyArray
from scipy.optimize import OptimizeResult
from isssm.typing import GLSSMProposal
from typing import Optional


class HospitalizationResult(NamedTuple):
    """Everything produced by fitting the hospitalization model to one age group.

    Instances are pickled to disk by ``showcase_result`` and re-loaded later
    for visualisation and prediction.
    """

    a_index: int  # index of the age group in ``unique_a`` / ``h_by_age`` / ``I_by_age``
    # hospitalizations of this age group, first ``n_delay`` columns, all time points
    h: Float[Array, "..."]
    # incidences of this age group, restricted to the fitting window
    I: Float[Array, "..."]
    # observations used for fitting: ``h`` restricted to the fitting window
    y: Float[Array, "np1 p"]
    theta0: Float[Array, "5"]  # initial parameter estimate (currently used for fitting)
    theta0_result: OptimizeResult  # optimizer output of ``initial_theta``
    # maximum likelihood estimate; None while the MLE step is disabled
    theta_hat: Optional[Float[Array, "5"]]
    # optimizer output of the MLE step; None while the MLE step is disabled
    theta_hat_result: Optional[OptimizeResult]
    proposal: GLSSMProposal  # importance sampling proposal (MEIS, started from the LA)
    key: PRNGKeyArray  # PRNG key state after fitting, for reproducible follow-ups
    ess_pct: Float  # value of ``ess_pct`` for the final proposal (EF in %)


def showcase_result(a_index: int, n_delay: int) -> HospitalizationResult:
    """Fit the hospitalization model for one age group and pickle the result.

    The fitting window starts at ``start_date`` and ends just before
    ``end_date`` (module-level globals). The initial parameter estimate from
    ``initial_theta`` is used directly, because the MLE step is currently
    disabled. The proposal is obtained by a Laplace approximation refined with
    MEIS. Its effective sample size is estimated from ``N_ess`` importance
    samples.

    Args:
        a_index: index of the age group in ``unique_a``.
        n_delay: number of delay columns of ``h_by_age[a_index]`` to model.

    Returns:
        The fitted ``HospitalizationResult``. It is also pickled to
        ``data/results/4_hospitalizations/showcase/result_A{a_index}_D{n_delay}.pkl``.
    """
    i_start = int(jnp.where(dates == start_date)[0][0])
    np1 = int(jnp.where(dates == end_date)[0][0]) - i_start

    h = h_by_age[a_index][:, :n_delay]
    I = I_by_age[a_index][i_start : i_start + np1]
    y = h[i_start : i_start + np1]

    key = jrn.PRNGKey(4234809124)
    aux = (np1, n_delay, 2, I)

    # Manual starting point for ``initial_theta``. The last entry is the ratio
    # of observed hospitalizations to incidences over the first 7 time points.
    p0_hat = y[:7].sum() / I[:7].sum()
    theta_manual = jnp.log(jnp.array([1**2, 0.1**2, 0.1**2, 0.1**2, p0_hat]))
    theta0_result = initial_theta(y, hospitalization_model, theta_manual, aux, n_iter)
    theta0 = theta0_result.x

    # The MLE step is disabled. To enable it, use
    #     mle_result = mle_pgssm(y, hospitalization_model, theta0, aux, n_iter, N_mle, subkey)
    #     theta_hat = mle_result.x
    # and fit the model with ``theta_hat`` below. The split is kept either way,
    # so the PRNG stream (and thus all later results) does not change.
    key, subkey = jrn.split(key)
    mle_result = None
    theta_hat = None

    fitted_model = hospitalization_model(theta0, aux)
    proposal_la, _ = LA(y, fitted_model, n_iter)

    key, subkey = jrn.split(key)
    proposal, _ = MEIS(
        y, fitted_model, proposal_la.z, proposal_la.Omega, n_iter, N_meis, subkey
    )

    key, subkey = jrn.split(key)
    _, log_weights = pgssm_importance_sampling(
        y, fitted_model, proposal.z, proposal.Omega, N_ess, subkey
    )

    result = HospitalizationResult(
        a_index=a_index,
        h=h,
        I=I,
        y=y,
        theta0=theta0,
        theta0_result=theta0_result,
        theta_hat=theta_hat,
        theta_hat_result=mle_result,
        proposal=proposal,
        key=key,
        ess_pct=ess_pct(log_weights),
    )

    out_dir = here() / "data" / "results" / "4_hospitalizations" / "showcase"
    out_dir.mkdir(parents=True, exist_ok=True)  # open(..., "wb") fails on a missing dir
    with open(out_dir / f"result_A{a_index}_D{n_delay}.pkl", "wb") as f:
        pickle.dump(result, f)
    print(f"Finished results for a_index {a_index} and n_delay {n_delay}")
    return result


# Overview: observations of every age group over the fitting window, and the
# incidences of all age groups in the last panel.
i_start = int(jnp.where(dates == start_date)[0][0])
np1 = int(jnp.where(dates == end_date)[0][0]) - i_start

fig, axs = plt.subplots(4, 2, figsize=(20, 10))
axs = axs.flatten()
for a_index, age_group in enumerate(unique_a):
    y = h_by_age[a_index][i_start : i_start + np1]
    axs[a_index].plot(y)
    axs[a_index].set_title(f"{age_group}")
axs[-1].plot(jnp.array(I_by_age).T[i_start : i_start + np1], label=unique_a)
axs[-1].legend()
plt.show()

# Recorded notebook output of fitting all age groups with ``showcase_result``:
# Finished results for a_index 0 and n_delay 5
# Finished results for a_index 1 and n_delay 5
# Finished results for a_index 2 and n_delay 7
# Finished results for a_index 3 and n_delay 7
# Finished results for a_index 4 and n_delay 8
# Finished results for a_index 5 and n_delay 8
# Finished results for a_index 6 and n_delay 8
from ssm4epi.models.hospitalization import visualize_model_fit


def visualize_result(result: HospitalizationResult):
    """Plot the model fit stored in ``result``.

    The function rebuilds the model from ``result.theta0`` and the stored
    observations and incidences. It draws ``N_fit`` importance samples from
    the stored proposal, using a key split from ``result.key``, and passes
    them to ``visualize_model_fit``. The start index of the plot comes from
    the module-level ``start_date``.
    """
    i_start = int(jnp.where(dates == start_date)[0][0])
    np1, n_delay = result.y.shape
    aux = (np1, n_delay, 2, result.I)
    model = hospitalization_model(result.theta0, aux)

    _, subkey = jrn.split(result.key)
    samples, log_weights = pgssm_importance_sampling(
        result.y,
        model,
        result.proposal.z,
        result.proposal.Omega,
        N_fit,
        subkey,
    )
    visualize_model_fit(
        samples, log_weights, model, i_start, np1, result.y, result.a_index
    )


visualize_result(results[0])
from isssm.importance_sampling import prediction

key = jrn.PRNGKey(2342312412)
key, subkey_prediction = jrn.split(key)


def f_pred(x, s, y):
    """Map one sample of states ``x`` and signals ``s`` to the quantities to predict.

    ``y`` is only used for its shape: its second axis gives ``n_delay``.
    NOTE(review): presumably ``x`` is (time, state_dim) and ``s`` is
    (time, signal_dim); confirm against ``prediction``.

    Returns the following, concatenated along the last axis:
        p                    -- exp of the first signal component
        delay_dist_smooth    -- delay distribution from states 2 .. n_delay
        delay_dist_observed  -- delay distribution from the remaining signals
        weekday_1, weekday_2 -- exp of the two weekday state components
        average_delay        -- mean delay under ``delay_dist_smooth``
    """
    _, n_delay = y.shape
    p = jnp.exp(s[:, :1])
    delay_dist_smooth = from_consecutive_logits(x[:, 2 : (2 + n_delay - 1)])
    delay_dist_observed = from_consecutive_logits(s[:, 1:])
    weekday_1 = jnp.exp(x[:, (2 + n_delay - 1)][:, None])
    weekday_2 = jnp.exp(x[:, (2 + n_delay - 1 + 6)][:, None])
    average_delay = jnp.sum(
        delay_dist_smooth * jnp.arange(n_delay), axis=-1, keepdims=True
    )
    return jnp.concatenate(
        (
            p,
            delay_dist_smooth,
            delay_dist_observed,
            weekday_1,
            weekday_2,
            average_delay,
        ),
        axis=-1,
    )


def stacked_prediction(f, result: HospitalizationResult):
    """Predict ``f`` under the model fitted in ``result``.

    Uses ``N_predict`` samples from ``result.proposal``. The global
    ``subkey_prediction`` is used for every result, so all age groups share
    the same random draws.

    Returns an array that stacks the mean, the standard deviation and the
    quantiles at ``percentiles_of_interest`` along its first axis.
    """
    np1, n_delay = result.y.shape
    fitted_model = hospitalization_model(result.theta0, (np1, n_delay, 2, result.I))
    mean, sd, quantiles = prediction(
        f,
        result.y,
        result.proposal,
        fitted_model,
        N_predict,
        subkey_prediction,
        percentiles_of_interest,
    )
    return jnp.vstack((mean[None], sd[None], quantiles))


showcase_dir = here() / "data/results/4_hospitalizations/showcase"
showcase_dir.mkdir(parents=True, exist_ok=True)

# Wrap ``results`` itself (not the enumerate object) so tqdm knows the total.
for i, result in enumerate(tqdm(results)):
    jnp.save(
        here()
        / f"data/results/4_hospitalizations/showcase/predictions_age_{unique_a[i]}.npy",
        stacked_prediction(f_pred, result),
    )

# data/results/4_hospitalizations/showcase/hyperparams.tex
hyperparams_statement = f"""
To fit the model for all age groups, we use at most {n_iter} iterations for all occurrences of the \\gls{{la}}, with a convergence threshold set to $10^{{-5}}$ relative difference in $z$ and $\\Omega$. We use the same method for \\gls{{eis}}, where we additionally use ${N_meis:,}$ samples to determine the optimal proposal, starting with the proposal given by the \\gls{{la}}.
For \\gls{{mle}}, we use ${N_mle:,}$ samples to determine the maximum likelihood estimate of the parameters, initializing at the initial guess given by \\Cref{{alg:mle}}.
To obtain prediction intervals of the states, signals and missing observations we use ${N_predict:,}$ samples. To estimate the \\gls{{ess}} we use ${N_ess:,}$ samples. Note that all sample sizes are given before constructing the three antithetic variables.
"""
# write to tex file
with open(
    here() / "data/results/4_hospitalizations/showcase/hyperparams.tex",
    "w",
    encoding="utf-8",
) as f:
    f.write(hyperparams_statement)

import pandas as pd

# Build the table before opening the file, so that a failure does not leave
# an empty .tex file behind.
table = pd.DataFrame(
    {
        "Age group": [f"A{a}" for a in unique_a],
        "EF [\\%]": [f"{result.ess_pct:.0f}" for result in results],
        "weeks of delay": [
            n_delay for _, (_, n_delay) in delays_per_age.iterrows()
        ],
    }
).to_latex(index=False)
tables_dir = here() / "tables"
tables_dir.mkdir(parents=True, exist_ok=True)
with open(tables_dir / "hospitalization_showcase_ess.tex", "w", encoding="utf-8") as f:
    f.write(table)